Chapter 3: Sampling the Imaginary

[1]:
%load_ext jupyter_black
[2]:
import random
from typing import Sequence

import jax
import arviz as az
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
import plotly
import plotly.graph_objects as go
import plotly.io as pio
from scipy import stats, optimize

pd.options.plotting.backend = "plotly"

seed = 84735
pio.templates.default = "plotly_white"
rng = np.random.default_rng(seed=seed)
jrng = jax.random.PRNGKey(seed)

Code

Code 3.1

[3]:
p_positive_vampire = 0.95
p_positive_mortal = 0.01
p_vampire = 0.001
p_positive = p_positive_vampire * p_vampire + p_positive_mortal * (1 - p_vampire)
p_vampire_positive = p_positive_vampire * p_vampire / p_positive
p_vampire_positive
[3]:
0.08683729433272395

Code 3.2

[4]:
def calculate_posterior_numpyro(W: int, L: int, prior: Sequence[float], grid_size: int):
    """Grid-approximate posterior for the globe-tossing model (numpyro version).

    Evaluates the binomial likelihood of W water observations in W + L tosses
    on an evenly spaced probability grid, weights it by the prior, and
    normalizes so the result sums to one.
    """
    probabilities = jnp.linspace(0, 1, grid_size)
    log_likelihood = dist.Binomial(total_count=W + L, probs=probabilities).log_prob(W)
    unnormalized = prior * jnp.exp(log_likelihood)
    return unnormalized / unnormalized.sum()


def calculate_posterior(W: int, L: int, prior: Sequence[float], grid_size: int):
    """Grid-approximate posterior for the globe-tossing model (scipy version).

    Computes the binomial pmf of W water observations in W + L tosses across
    an evenly spaced probability grid, multiplies by the prior, and normalizes
    to a proper probability mass function over the grid.
    """
    probabilities = jnp.linspace(0, 1, grid_size)
    binom_pmf = stats.binom.pmf(k=W, n=W + L, p=probabilities)
    unnormalized = prior * binom_pmf
    return unnormalized / unnormalized.sum()


# Globe-tossing data: 6 water (W) out of 9 tosses, approximated on a
# 1,000-point grid with a flat (unnormalized) prior.
W = 6
L = 3
grid_size = 1_000
prior = jnp.full(grid_size, 1)
p_grid = jnp.linspace(0, 1, grid_size)
posterior = calculate_posterior(W, L, prior, grid_size)

Code 3.3

[5]:
samples_numpyro = p_grid[dist.Categorical(probs=posterior).sample(jrng, (10_000,))]
[6]:
samples = rng.choice(p_grid, size=10_000, replace=True, p=posterior)

Code 3.4

[7]:
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=jnp.arange(10_000),
        y=samples,
        mode="markers",
        line={"color": "rgba(0, 0, 255, 0.2)"},
    )
)

Code 3.5

[8]:
az.plot_density({"": samples}, hdi_prob=1)
[8]:
array([[<Axes: >]], dtype=object)
../_images/notebooks_03_sampling_the_imaginary_13_1.png

Code 3.6

[9]:
posterior[p_grid < 0.5].sum()
[9]:
Array(0.17187458, dtype=float32)

Code 3.7

[10]:
jnp.sum(samples < 0.5) / samples.shape[0]
[10]:
Array(0.1675, dtype=float32)

Code 3.8

[11]:
jnp.sum(jnp.logical_and(samples > 0.5, samples < 0.75)) / samples.shape[0]
[11]:
Array(0.6135, dtype=float32)

Code 3.9

[12]:
jnp.quantile(samples, 0.8)
[12]:
Array(0.7587587, dtype=float32)

Code 3.10

[13]:
jnp.quantile(samples, jnp.array([0.1, 0.9]))
[13]:
Array([0.45335338, 0.8118118 ], dtype=float32)

Code 3.11

[14]:
posterior = calculate_posterior(W=3, L=0, prior=jnp.full(1_000, 1), grid_size=1_000)
samples_skewed = p_grid[dist.Categorical(probs=posterior).sample(jrng, (10_000,))]

Code 3.12

[15]:
def percentile_interval(samples, prob):
    """Central (equal-tailed) percentile interval containing `prob` mass.

    Returns the quantiles at (1 - prob) / 2 and 1 - (1 - prob) / 2, i.e. the
    middle `prob` of the sample distribution (matching rethinking's PI).

    Fix: the previous `prob = min(prob, 1 - prob)` folded the argument, so
    e.g. prob=0.2 returned the 80% interval instead of the 20% interval.
    """
    tail = (1 - prob) / 2  # probability mass left in each tail
    return jnp.quantile(samples, jnp.array([tail, 1 - tail]))


percentile_interval(samples_skewed, 0.5)
[15]:
Array([0.7067067, 0.9319319], dtype=float32)

Code 3.13

[16]:
numpyro.diagnostics.hpdi(samples_skewed, prob=0.5)
[16]:
array([0.8398398, 0.998999 ], dtype=float32)
[17]:
az.hdi(np.array(samples_skewed), hdi_prob=0.5)
[17]:
array([0.8398398, 0.998999 ], dtype=float32)

Code 3.14

[18]:
p_grid[jnp.argmax(posterior)]
[18]:
Array(1., dtype=float32)

Code 3.15

[19]:
samples_skewed[
    jnp.argmax(stats.gaussian_kde(samples_skewed, bw_method=0.01)(samples_skewed))
]
[19]:
Array(0.985986, dtype=float32)

Code 3.16

[20]:
display(samples_skewed.mean())
jnp.median(samples_skewed)
Array(0.8006291, dtype=float32)
[20]:
Array(0.8408408, dtype=float32)

Code 3.17

[21]:
jnp.sum(jnp.abs(0.5 - p_grid) * posterior)
[21]:
Array(0.31287518, dtype=float32)

Code 3.18

[22]:
loss = jax.vmap(lambda d: jnp.sum(jnp.abs(d - p_grid) * posterior))(p_grid)
fig = go.Figure(data=go.Scatter(x=p_grid, y=loss))
fig.update_layout(
    xaxis={"title": "parameter"},
    yaxis={"title": "expected loss"},
)
fig.show()

Code 3.19

[23]:
p_grid[jnp.argmin(loss)]
[23]:
Array(0.8408408, dtype=float32)

Code 3.20

[24]:
jnp.exp(dist.Binomial(total_count=2, probs=0.7).log_prob(jnp.arange(3)))
[24]:
Array([0.08999996, 0.42000008, 0.48999974], dtype=float32)
[25]:
stats.binom.pmf(k=jnp.arange(3), p=0.7, n=2)
[25]:
array([0.09, 0.42, 0.49])

Code 3.21

[26]:
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample("dummy_w", dist.Binomial(total_count=2, probs=0.7))
dummy_w
[26]:
Array(2, dtype=int32, weak_type=True)
[27]:
stats.binom.rvs(n=2, p=0.7, random_state=rng)
[27]:
2

Code 3.22

[28]:
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=2, probs=0.7), sample_shape=(10,)
    )
dummy_w
[28]:
Array([0, 1, 2, 1, 1, 2, 1, 2, 1, 2], dtype=int32, weak_type=True)
[29]:
stats.binom.rvs(n=2, p=0.7, size=(10,), random_state=rng)
[29]:
array([1, 2, 2, 1, 2, 2, 2, 1, 2, 1])

Code 3.23

[30]:
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=2, probs=0.7), sample_shape=(100_000,)
    )
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"])
dummy_w["freq"] = 1
dummy_w.groupby("dummy_w").sum() / 100_000
[30]:
freq
dummy_w
0 0.09004
1 0.42109
2 0.48887
[31]:
# Sample-frequency check of Binomial(2, 0.7) with scipy.
# Fix: the original drew 10_000 samples but divided counts by 100_000, so the
# displayed "frequencies" summed to ~0.1. Draw 100_000 samples (matching the
# numpyro cell above) and normalize by the actual sample count.
n_draws = 100_000
dummy_w = stats.binom.rvs(n=2, p=0.7, size=(n_draws,), random_state=rng)
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"])
dummy_w["freq"] = 1
dummy_w.groupby("dummy_w").sum() / n_draws
[31]:
freq
dummy_w
0 0.00923
1 0.04194
2 0.04883

Code 3.24

[32]:
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=9, probs=0.7), sample_shape=(100_000,)
    )
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"])
dummy_w.plot(kind="hist")
[33]:
dummy_w = stats.binom.rvs(n=9, p=0.7, size=(100_000,), random_state=rng)
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"])
dummy_w.plot(kind="hist")

Code 3.25

[34]:
w = dist.Binomial(total_count=9, probs=0.6).sample(jrng, (10_000,))
pd.DataFrame(w).plot(kind="hist")
[35]:
w = stats.binom.rvs(n=9, p=0.6, size=(100_000,), random_state=rng)
pd.DataFrame(w).plot(kind="hist")

Code 3.26

[36]:
w = dist.Binomial(total_count=9, probs=samples).sample(jrng)
pd.DataFrame(w).plot(kind="hist")
[37]:
w = stats.binom.rvs(n=9, p=samples, random_state=rng)
pd.DataFrame(w).plot(kind="hist")

Easy

[38]:
def calculate_posterior(W: int, L: int, grid_size: int):
    """Grid posterior for W water in W + L tosses under a uniform prior.

    Returns a dict with the normalized posterior over the grid ("pdf") and
    10,000 draws of p sampled from it ("samples").
    """
    p_grid = jnp.linspace(0, 1, grid_size)
    # Uniform prior, built directly as an array instead of via a Python list.
    prior = jnp.full(grid_size, 1 / grid_size)
    likelihood = jnp.exp(dist.Binomial(total_count=W + L, probs=p_grid).log_prob(W))
    raw_posterior = prior * likelihood
    posterior = raw_posterior / raw_posterior.sum()
    samples = p_grid[dist.Categorical(probs=posterior).sample(jrng, (10_000,))]
    return {"pdf": posterior, "samples": samples}


grid_size = 1_000
posterior = calculate_posterior(W=6, L=3, grid_size=grid_size)
pdf = posterior["pdf"]
samples = posterior["samples"]

3E1

[39]:
jnp.sum(samples < 0.2) / samples.shape[0]
[39]:
Array(0.0008, dtype=float32)
[40]:
p_grid = jnp.linspace(0, 1, grid_size)
pdf[p_grid < 0.2].sum()
[40]:
Array(0.0008561, dtype=float32)

3E2

[41]:
jnp.sum(samples > 0.8) / samples.shape[0]
[41]:
Array(0.1239, dtype=float32)
[42]:
pdf[p_grid > 0.8].sum()
[42]:
Array(0.12034493, dtype=float32)

3E3

[43]:
jnp.sum(jnp.logical_and(0.2 < samples, samples < 0.8)) / samples.shape[0]
[43]:
Array(0.8753, dtype=float32)
[44]:
pdf[np.logical_and(0.2 < p_grid, p_grid < 0.8)].sum()
[44]:
Array(0.87879896, dtype=float32)

3E4

[45]:
jnp.quantile(samples, 0.2)
[45]:
Array(0.51631635, dtype=float32)
[46]:
p_grid[jnp.searchsorted(pdf.cumsum(), 0.2)]
[46]:
Array(0.5165165, dtype=float32)

3E5

[47]:
jnp.quantile(samples, 0.8)
[47]:
Array(0.7617617, dtype=float32)
[48]:
p_grid[jnp.searchsorted(pdf.cumsum(), 0.8)]
[48]:
Array(0.7607607, dtype=float32)

3E6

[49]:
numpyro.diagnostics.hpdi(samples, prob=0.66)
[49]:
array([0.5095095, 0.7837838], dtype=float32)

3E7

[50]:
jnp.quantile(samples, jnp.array([(1 - 0.66) / 2, 1 - (1 - 0.66) / 2]))
[50]:
Array([0.49732742, 0.7747748 ], dtype=float32)

Medium

3M1

[51]:
def calculate_posterior(W: int, L: int, grid_size: int, n_samples: int):
    """Grid posterior (uniform prior) with a configurable number of samples.

    Returns a dict with the normalized posterior over the grid ("pdf") and
    `n_samples` draws of p sampled from it ("samples").
    """
    p_grid = jnp.linspace(0, 1, grid_size)
    # Uniform prior, built directly as an array instead of via a Python list.
    prior = jnp.full(grid_size, 1 / grid_size)
    likelihood = jnp.exp(dist.Binomial(total_count=W + L, probs=p_grid).log_prob(W))
    raw_posterior = prior * likelihood
    posterior = raw_posterior / raw_posterior.sum()
    samples = p_grid[dist.Categorical(probs=posterior).sample(jrng, (n_samples,))]
    return {"pdf": posterior, "samples": samples}


grid_size = 1_000
n_samples = 10_000
posterior = calculate_posterior(W=8, L=7, grid_size=grid_size, n_samples=n_samples)
pdf = posterior["pdf"]
samples = posterior["samples"]
pd.DataFrame(
    pdf, index=pd.Index(jnp.linspace(0, 1, grid_size), name="posterior proba of W")
).plot()

3M2

[52]:
numpyro.diagnostics.hpdi(samples, prob=0.9)
[52]:
array([0.33333334, 0.7177177 ], dtype=float32)

3M3

[53]:
# Posterior predictive: probability of exactly 8 waters in 15 new tosses,
# averaging over the posterior samples of p.
with numpyro.handlers.seed(rng_seed=seed):
    n_water = numpyro.sample("n_water", dist.Binomial(total_count=15, probs=samples))
p_8_w_7_l = jnp.sum(n_water == 8) / n_water.shape[0]
print(f"Probability 8/15 water is {p_8_w_7_l:.2%}")  # fixed typo: "Probabilty"
pd.DataFrame(n_water).plot(kind="hist")
Probability 8/15 water is 15.05%

3M4

[54]:
# Posterior predictive: probability of exactly 6 waters in 9 new tosses,
# averaging over the posterior samples of p.
with numpyro.handlers.seed(rng_seed=seed):
    n_water = numpyro.sample("n_water", dist.Binomial(total_count=9, probs=samples))
p_6_w_3_l = jnp.sum(n_water == 6) / n_water.shape[0]
print(f"Probability 6/9 water is {p_6_w_3_l:.2%}")  # fixed typo: "Probabilty"
pd.DataFrame(n_water).plot(kind="hist")
Probability 6/9 water is 17.71%

3M5

[55]:
def calculate_posterior(W: int, L: int, grid_size: int, n_samples: int):
    """Grid posterior with a step prior: zero mass below p = 0.5.

    Returns a dict with the normalized posterior over the grid ("pdf") and
    `n_samples` draws of p sampled from it ("samples").
    """
    p_grid = jnp.linspace(0, 1, grid_size)
    # Vectorized step prior (replaces a per-element Python loop over the
    # grid); values are identical to the original 0/1 list after normalizing.
    prior = jnp.where(p_grid < 0.5, 0.0, 1.0)
    prior /= prior.sum()
    likelihood = jnp.exp(dist.Binomial(total_count=W + L, probs=p_grid).log_prob(W))
    raw_posterior = prior * likelihood
    posterior = raw_posterior / raw_posterior.sum()
    samples = p_grid[dist.Categorical(probs=posterior).sample(jrng, (n_samples,))]
    return {"pdf": posterior, "samples": samples}


grid_size = 1_000
n_samples = 10_000
posterior = calculate_posterior(W=8, L=7, grid_size=grid_size, n_samples=n_samples)
pdf = posterior["pdf"]
samples = posterior["samples"]
pd.DataFrame(
    pdf, index=pd.Index(jnp.linspace(0, 1, grid_size), name="posterior proba of W")
).plot().show()
az.plot_dist(samples)
[55]:
<Axes: >
../_images/notebooks_03_sampling_the_imaginary_95_2.png
[56]:
# Posterior predictive for 8/15 waters using the step-prior samples (scipy).
n_water = stats.binom.rvs(n=15, p=samples, random_state=rng)
p_8_w_7_l = jnp.sum(n_water == 8) / n_water.shape[0]
print(f"Probability 8/15 water is {p_8_w_7_l:.2%}")  # fixed typo: "Probabilty"
pd.DataFrame(n_water).plot(kind="hist")
Probability 8/15 water is 16.11%

3M6

[57]:
def pi_width(n_tosses, n_posterior_samples, n_experiments, grid_size: int = 1_000):
    """Mean width of the 99% percentile interval for p across experiments.

    Simulates `n_experiments` globe-tossing runs of `n_tosses` tosses each
    with true p = 0.7, fits a grid posterior (uniform prior) to each run,
    draws `n_posterior_samples` posterior samples per run, and returns the
    average width of the 0.5%-99.5% percentile interval.

    Changes vs. the original: `grid_size` is now an explicit parameter
    (previously read from a module-level global of the same value); the dead
    `width` list was removed; posteriors are normalized per experiment row
    (the old grand-total normalization only worked because Categorical
    sampling is invariant to a common scale factor).
    """
    p_grid = jnp.linspace(0, 1, grid_size)
    prior = jnp.full(grid_size, 1 / grid_size)  # uniform prior over the grid

    # One simulated water count per experiment, true p = 0.7.
    n_water = stats.binom.rvs(
        n=n_tosses, p=0.7, random_state=rng, size=(n_experiments,)
    )
    # Likelihood rows: one grid evaluation per experiment.
    likelihood = jnp.exp(
        jnp.array(
            [
                dist.Binomial(total_count=n_tosses, probs=p_grid).log_prob(_n_water)
                for _n_water in n_water
            ]
        )
    )
    raw_posterior = prior * likelihood
    # Normalize each experiment's posterior separately.
    posterior = raw_posterior / raw_posterior.sum(axis=-1, keepdims=True)
    samples = jnp.array(
        [
            p_grid[
                dist.Categorical(probs=_posterior).sample(jrng, (n_posterior_samples,))
            ]
            for _posterior in posterior
        ]
    )
    quantiles = jnp.quantile(samples, jnp.array([0.005, 0.995]), axis=1)
    return float((quantiles[1] - quantiles[0]).mean())


def objective(n_tosses):
    """Root-finding target: signed distance of the mean 99% interval width from 0.05."""
    width = pi_width(int(n_tosses), n_posterior_samples=10_000, n_experiments=100)
    return width - 0.05
[58]:
n_tosses = list(range(500, 5_500, 500))
widths = [
    pi_width(_n_tosses, n_posterior_samples=10_000, n_experiments=100)
    for _n_tosses in n_tosses
]
[59]:
fig = go.Figure(data=go.Scatter(x=n_tosses, y=widths))
fig.add_hline(y=0.05, line={"color": "red", "dash": "dash"})
fig.update_layout(
    xaxis={"title": "number of tosses"},
    yaxis={"title": "width of 99% compatibility interval"},
)
fig
[60]:
result = optimize.root_scalar(objective, x0=2_000, bracket=[1_500, 2_500], xtol=1)
result
[60]:
      converged: True
           flag: converged
 function_calls: 9
     iterations: 8
           root: 2189.1438517842357
         method: brentq

3H1

[61]:
# fmt: off
births_1 = [
    1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
    0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0,
    0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,
    0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
]
births_2 = [
    0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1,
    1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,
    0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
]
births = jnp.array([births_1, births_2])
[62]:
# Grid posterior for the probability of a boy, pooling both birth vectors.
grid_size = 1_000

p_grid = jnp.linspace(0, 1, grid_size)
# Uniform prior built directly as an array (was a Python list).
prior = jnp.full(grid_size, 1 / grid_size)
# Fixed variable-name typo: "likelhihood" -> "likelihood".
likelihood = jnp.exp(
    dist.Binomial(total_count=births.size, probs=p_grid).log_prob(births.sum())
)
raw_posterior = likelihood * prior
posterior = raw_posterior / raw_posterior.sum()
map_p = p_grid[jnp.argmax(posterior)]
print(f"p={map_p:.2%} maximizes the posterior probability.")
p=55.46% maximizes the posterior probability.

3H2

[63]:
posterior_samples = p_grid[dist.Categorical(probs=posterior).sample(jrng, (10_000,))]
print(f"50% HDPI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.5)}")
print(f"89% HDPI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.89)}")
print(f"97% HDPI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.97)}")
50% HDPI: [0.5275275 0.5745746]
89% HDPI: [0.4994995 0.6096096]
97% HDPI: [0.47847846 0.6286286 ]

3H3

[64]:
posterior_predictive_samples = dist.Binomial(
    total_count=births.size, probs=posterior_samples
).sample(jrng)
print(
    "Posterior Predictive Check:\n"
    "Samples from posterior predictive distribution have, on average, "
    f"{posterior_predictive_samples.mean():.0f} boys "
    f"vs {births.sum()} in our training data"
)
fig = pd.DataFrame(posterior_predictive_samples, columns=["n_boys"]).plot(kind="hist")
fig.add_vline(x=births.sum(), line={"color": "red"})
Posterior Predictive Check:
Samples from posterior predictive distribution have, on average, 111 boys vs 111 in our training data

3H4

[65]:
# Posterior predictive for first-born boys only (100 first births).
posterior_predictive_samples = dist.Binomial(
    total_count=births.shape[1], probs=posterior_samples
).sample(jrng)
print(
    f"Posterior predictive distribution of first born sons has mean {posterior_predictive_samples.mean():.0f} "
    # fixed typo: "obersvation" -> "observation"
    f"vs observation of {births[0].sum()}; still reasonable but not as good as purely 'in-sample'."
)
fig = pd.DataFrame(posterior_predictive_samples, columns=["n_first_born_boys"]).plot(
    kind="hist"
)
fig.add_vline(x=births[0, :].sum(), line={"color": "red"})
Posterior predictive distribution of first born sons has mean 55 vs observation of 51; still reasonable but not as good as purely 'in-sample'.

3H5

[66]:
posterior_predictive_samples.shape
[66]:
(10000,)
[67]:
posterior_predictive_samples = dist.Binomial(
    total_count=jnp.logical_not(births[0]).sum(), probs=posterior_samples
).sample(jrng)
is_big_sister = births[0, :] == 0
print(
    f"PoPD of boys with big sisters of {posterior_predictive_samples.mean():.0f} "
    f"is completely out of line with observations of {births[1, is_big_sister].sum()}: we didn't model "
    "the correlation between first and second birth that's present in our dataset."
)
fig = pd.DataFrame(
    posterior_predictive_samples, columns=["n_boys_with_big_sister"]
).plot(kind="hist")

fig.add_vline(x=births[1, is_big_sister].sum(), line={"color": "red"})
PoPD of boys with big sisters of 27 is completely out of line with observations of 39: we didn't model the correlation between first and second birth that's present in our dataset.